{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "702b1d59",
   "metadata": {},
   "source": [
    "## Requirements\n",
    "- mindspore==2.3.1\n",
    "- mindnlp==0.4.1\n",
    "\n",
    "## Data\n",
    "Download the GLUE data from this [link](https://gluebenchmark.com/tasks); the download option for the main zip file is on the right side of the page. Extract the archive and place the SST-2 files (`train.tsv` and `dev.tsv`) in `data/SST-2/`."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "575c9caa",
   "metadata": {},
   "source": [
    "导入所需库"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "aca6d5f7",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "Building prefix dict from the default dictionary ...\n",
      "Loading model from cache /tmp/jieba.cache\n",
      "Loading model cost 1.322 seconds.\n",
      "Prefix dict has been built successfully.\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import time\n",
    "\n",
    "import pandas as pd\n",
    "\n",
    "import mindspore\n",
    "from mindspore import save_checkpoint\n",
    "\n",
    "from mindnlp.core.nn import BCEWithLogitsLoss, Tensor\n",
    "from mindnlp.core.optim import Adam\n",
    "from mindnlp.transformers import BertModel, BertTokenizer\n",
    "from mindnlp.core import nn, value_and_grad\n",
    "from mindnlp.core.ops import sigmoid"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e8c93751",
   "metadata": {},
   "source": [
    "自定义模型类"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50513139",
   "metadata": {},
   "outputs": [],
   "source": [
    "class SentimentClassifier(nn.Module):\n",
    "    '''BERT encoder followed by a single-logit linear head for binary sentiment classification.\n",
    "\n",
    "    Args:\n",
    "        base_model_name_or_path: name or local path passed to `BertModel.from_pretrained`.\n",
    "        freeze_bert: if True, every BERT parameter gets `requires_grad = False`, the intent\n",
    "            being that only the linear classification head is trained.\n",
    "    '''\n",
    "    def __init__(self, base_model_name_or_path = 'bert-base-uncased', freeze_bert = True):\n",
    "        super().__init__()\n",
    "        # Instantiate the pretrained BERT encoder\n",
    "        self.bert_layer = BertModel.from_pretrained(base_model_name_or_path)\n",
    "\n",
    "        # Optionally freeze the encoder weights\n",
    "        if freeze_bert:\n",
    "            for p in self.bert_layer.parameters():\n",
    "                p.requires_grad = False\n",
    "\n",
    "        # Classification head sized from the encoder config instead of a hard-coded 768,\n",
    "        # so BERT variants with a different hidden size (e.g. bert-large) also work\n",
    "        self.cls_layer = nn.Linear(self.bert_layer.config.hidden_size, 1)\n",
    "\n",
    "    def forward(self, seq, attn_masks):\n",
    "        '''\n",
    "        Inputs:\n",
    "            -seq : Tensor of shape [B, T] containing token ids of sequences\n",
    "            -attn_masks : Tensor of shape [B, T] containing attention masks to be used to avoid contribution of PAD tokens\n",
    "        Returns:\n",
    "            Tensor of shape [B, 1] holding one raw (pre-sigmoid) logit per sequence\n",
    "        '''\n",
    "\n",
    "        # Run the encoder and take its final-layer hidden states\n",
    "        last_hs = self.bert_layer(seq, attention_mask = attn_masks).last_hidden_state\n",
    "\n",
    "        # Use the hidden state at position 0 (the [CLS] token) as the sequence representation\n",
    "        cls_rep = last_hs[:, 0]\n",
    "\n",
    "        # Project the [CLS] representation to a single logit\n",
    "        logits = self.cls_layer(cls_rep)\n",
    "\n",
    "        return logits"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5e9591b9",
   "metadata": {},
   "source": [
    "自定义数据加载"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f55da01",
   "metadata": {},
   "outputs": [],
   "source": [
    "class SSTDataset():\n",
    "    '''Map-style SST-2 dataset intended for `mindspore.dataset.GeneratorDataset`.\n",
    "\n",
    "    Reads a tab-separated file with 'sentence' and 'label' columns. Each item is a\n",
    "    `(token_ids, label)` pair where the tokens are [CLS] + word pieces + [SEP], either\n",
    "    right-padded with [PAD] or truncated to `maxlen` tokens (keeping a final [SEP]).\n",
    "    '''\n",
    "    def __init__(self, base_model_name_or_path, filename, maxlen):\n",
    "        # The whole file is held in memory as a pandas DataFrame\n",
    "        self.df = pd.read_csv(filename, delimiter = '\\t')\n",
    "        self.tokenizer = BertTokenizer.from_pretrained(base_model_name_or_path)\n",
    "        self.maxlen = maxlen\n",
    "\n",
    "    def __len__(self):\n",
    "        return self.df.shape[0]\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        sentence = self.df.loc[index, 'sentence']\n",
    "        label = self.df.loc[index, 'label']\n",
    "\n",
    "        # [CLS] <word pieces> [SEP]\n",
    "        pieces = ['[CLS]', *self.tokenizer.tokenize(sentence), '[SEP]']\n",
    "        shortfall = self.maxlen - len(pieces)\n",
    "        if shortfall > 0:\n",
    "            # Right-pad up to maxlen tokens\n",
    "            pieces = pieces + ['[PAD]'] * shortfall\n",
    "        else:\n",
    "            # Too long (or exactly maxlen): keep maxlen-1 tokens and re-append [SEP]\n",
    "            pieces = pieces[:self.maxlen - 1] + ['[SEP]']\n",
    "\n",
    "        # Map tokens to ids in the BERT vocabulary\n",
    "        return self.tokenizer.convert_tokens_to_ids(pieces), label\n",
    "\n",
    "def get_loader(dataset, batchsize, shuffle=True, num_workers=1, drop_remainder=True):\n",
    "    '''Wrap `dataset` in a batched MindSpore GeneratorDataset and return a dict iterator.\n",
    "\n",
    "    Each yielded item is a dict with the keys 'tokens_ids' and 'label' holding one batch.\n",
    "    '''\n",
    "    source = mindspore.dataset.GeneratorDataset(\n",
    "        source=dataset,\n",
    "        column_names=['tokens_ids', 'label'],\n",
    "        shuffle=shuffle,\n",
    "        num_parallel_workers=num_workers,\n",
    "    )\n",
    "    batched = source.batch(batch_size=batchsize, drop_remainder=drop_remainder)\n",
    "    return batched.create_dict_iterator()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "997ce3c5",
   "metadata": {},
   "source": [
    "自定义Trainer类"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "18fd6134",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_accuracy_from_logits(logits, labels):\n",
    "    '''Fraction of examples whose thresholded prediction equals `labels`.\n",
    "\n",
    "    An example counts as predicted positive when sigmoid(logit) > 0.5.\n",
    "    Returns a scalar float Tensor.\n",
    "    '''\n",
    "    predicted = (sigmoid(logits.unsqueeze(-1)) > 0.5).long().squeeze()\n",
    "    matches = (predicted == labels).float()\n",
    "    return matches.mean()\n",
    "\n",
    "def evaluate(net, criterion, dataloader):\n",
    "    '''Compute accuracy and loss of `net` over every batch yielded by `dataloader`.\n",
    "\n",
    "    Both metrics are weighted by batch size, so a shorter final batch (possible when the\n",
    "    loader keeps its remainder) does not skew the averages.\n",
    "\n",
    "    Returns:\n",
    "        (accuracy, loss) as Python floats; (0.0, 0.0) if the dataloader yields no batches.\n",
    "    '''\n",
    "    total_correct, total_loss = 0.0, 0.0\n",
    "    total_examples = 0\n",
    "\n",
    "    for data in dataloader:\n",
    "        tokens_ids = data['tokens_ids']\n",
    "        # Token id 0 is treated as padding (assumes [PAD] == 0, as in the BERT uncased vocab)\n",
    "        attn_mask = (tokens_ids != 0).long()\n",
    "        label = data['label']\n",
    "        batch_size = label.shape[0]\n",
    "        logits = net(tokens_ids, attn_mask)\n",
    "        # Assumes criterion averages over the batch (the BCEWithLogitsLoss default);\n",
    "        # rescale by batch size to accumulate a per-example sum\n",
    "        batch_loss = criterion(logits.squeeze(-1), label.astype('float32'))\n",
    "        total_loss += float(batch_loss.asnumpy()) * batch_size\n",
    "        total_correct += float(get_accuracy_from_logits(logits, label).asnumpy()) * batch_size\n",
    "        total_examples += batch_size\n",
    "\n",
    "    if total_examples == 0:\n",
    "        return 0.0, 0.0\n",
    "    return total_correct / total_examples, total_loss / total_examples\n",
    "\n",
    "class Trainer:\n",
    "    '''Minimal training loop: one optimizer step per batch, optional validation per epoch.\n",
    "\n",
    "    Args:\n",
    "        net: model called as net(tokens_ids, attn_mask) that returns logits.\n",
    "        criterion: loss called as criterion(logits.squeeze(-1), float_labels).\n",
    "        optimizer: optimizer exposing zero_grad() and step().\n",
    "        args: namespace providing `print_every` and `save_path` (None disables saving).\n",
    "        train_dataset: iterable of dicts with 'tokens_ids' and 'label' batches.\n",
    "        eval_dataset: optional iterable with the same layout, used for validation.\n",
    "    '''\n",
    "\n",
    "    def __init__(self, net, criterion, optimizer, args,\n",
    "                 train_dataset, eval_dataset=None\n",
    "                 ):\n",
    "        self.net = net\n",
    "        self.criterion = criterion\n",
    "        self.opt = optimizer\n",
    "        self.args = args\n",
    "        self.train_dataset = train_dataset\n",
    "        self.weights = self.net.trainable_params()\n",
    "        # Returns the loss; presumably also populates the gradients that opt.step()\n",
    "        # consumes, since step() below is called without explicit gradients\n",
    "        self.value_and_grad = value_and_grad(fn=self.forward_fn, params_or_argnums=self.weights)\n",
    "        self.run_eval = eval_dataset is not None\n",
    "        # Always defined (None when validation is disabled)\n",
    "        self.eval_dataset = eval_dataset\n",
    "        # Logits of the most recent forward pass, used for training-accuracy logging\n",
    "        self.logits = None\n",
    "\n",
    "    def forward_fn(self, tokens_ids_tensor, attn_mask, label):\n",
    "        '''Forward pass plus loss; also stores the logits on self for logging.'''\n",
    "        logits = self.net(tokens_ids_tensor, attn_mask)\n",
    "        self.logits = logits\n",
    "        loss = self.criterion(logits.squeeze(-1), label)\n",
    "        return loss\n",
    "\n",
    "    def train_single(self, tokens_ids_tensor, attn_mask, label):\n",
    "        '''Run one optimization step on a batch and return its loss.'''\n",
    "        self.opt.zero_grad()\n",
    "        loss = self.value_and_grad(tokens_ids_tensor, attn_mask, label)\n",
    "        self.opt.step()\n",
    "        return loss\n",
    "\n",
    "    def train(self, epochs):\n",
    "        '''Train for `epochs` epochs, validating (and optionally checkpointing) after each.'''\n",
    "        best_acc = 0\n",
    "        for epoch in range(epochs):\n",
    "            self.net.set_train(True)\n",
    "            for i, data in enumerate(self.train_dataset):\n",
    "                tokens_ids = data['tokens_ids']\n",
    "                # Token id 0 is treated as padding and masked out\n",
    "                attn_mask = Tensor((tokens_ids != 0).long())\n",
    "                label = data['label']\n",
    "\n",
    "                loss = self.train_single(tokens_ids, attn_mask, label.astype('float32'))\n",
    "\n",
    "                if i % self.args.print_every == 0:\n",
    "                    acc = get_accuracy_from_logits(self.logits, label)\n",
    "                    print(\"Iteration {} of epoch {} complete. Loss : {} Accuracy : {}\".format(i, epoch, loss.asnumpy(), acc))\n",
    "\n",
    "            if self.run_eval:\n",
    "                self.net.set_train(False)\n",
    "                val_acc, val_loss = evaluate(self.net, self.criterion, self.eval_dataset)\n",
    "                print(\"Epoch {} complete! Validation Accuracy : {}, Validation Loss : {}\".format(epoch, val_acc, val_loss))\n",
    "                if val_acc > best_acc:\n",
    "                    print(\"Best validation accuracy improved from {} to {}\".format(best_acc, val_acc))\n",
    "                    best_acc = val_acc\n",
    "                    if self.args.save_path is not None:\n",
    "                        print(\"saving model...\")\n",
    "                        # os.path.join works whether or not save_path ends with a separator.\n",
    "                        # NOTE(review): this branch was not exercised in the recorded runs\n",
    "                        # (save_path=None); confirm save_checkpoint accepts this model type.\n",
    "                        save_checkpoint(self.net, os.path.join(self.args.save_path, 'best_model.ckpt'))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6c404f16",
   "metadata": {},
   "source": [
    "主函数入口，完整训练流程"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ccb0012b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def main(args):\n",
    "    '''Build the SST-2 classifier, optimizer and data loaders from `args`, then train.\n",
    "\n",
    "    Reads args.device_target, device_id, base_model_name_or_path, freeze_bert, lr,\n",
    "    dataset_name_or_path (directory holding train.tsv and dev.tsv), maxlen, batch_size\n",
    "    and max_eps; the Trainer additionally reads args.print_every and args.save_path.\n",
    "    '''\n",
    "    #Instantiating the classifier model\n",
    "    print(\"Building model! (This might take time if you are running this for first time)\")\n",
    "    st = time.time()\n",
    "    # MindSpore warns if the context is changed after the backend was initialized,\n",
    "    # e.g. when main() is called a second time in the same kernel\n",
    "    mindspore.set_context(device_target=args.device_target, device_id=args.device_id)\n",
    "    net = SentimentClassifier(args.base_model_name_or_path, args.freeze_bert)\n",
    "    print(\"Done in {} seconds\".format(time.time() - st))\n",
    "\n",
    "    print(\"Creating criterion and optimizer objects\")\n",
    "    st = time.time()\n",
    "    # Binary cross-entropy on raw logits (the model applies no sigmoid itself)\n",
    "    criterion = BCEWithLogitsLoss()\n",
    "    opti = Adam(net.trainable_params(), lr=args.lr)\n",
    "    print(\"Done in {} seconds\".format(time.time() - st))\n",
    "\n",
    "    #Creating dataloaders\n",
    "    print(\"Creating train and val dataloaders\")\n",
    "    st = time.time()\n",
    "    # os.path.join avoids doubled or missing separators in the dataset path\n",
    "    train_set = SSTDataset(args.base_model_name_or_path, filename = os.path.join(args.dataset_name_or_path, 'train.tsv'), maxlen = args.maxlen)\n",
    "    val_set = SSTDataset(args.base_model_name_or_path, filename = os.path.join(args.dataset_name_or_path, 'dev.tsv'), maxlen = args.maxlen)\n",
    "\n",
    "    train_loader = get_loader(train_set, batchsize=args.batch_size)\n",
    "    # Validation keeps every example (no drop_remainder) in a fixed order (no shuffle)\n",
    "    val_loader = get_loader(val_set, batchsize=args.batch_size, shuffle=False, drop_remainder=False)\n",
    "    print(\"Done in {} seconds\".format(time.time() - st))\n",
    "\n",
    "    print(\"Let the training begin\")\n",
    "    st = time.time()\n",
    "    trainer = Trainer(net=net, criterion=criterion, optimizer=opti, args=args, train_dataset=train_loader, eval_dataset=val_loader)\n",
    "    trainer.train(epochs=args.max_eps)\n",
    "    print(\"Done in {} seconds\".format(time.time() - st))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f686b3cc",
   "metadata": {},
   "source": [
    "设置训练参数，开始训练(冻结BERT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "58d49a2b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Building model! (This might take time if you are running this for first time)\n",
      "[MS_ALLOC_CONF]Runtime config:  enable_vmm:True  vmm_align_size:2MB\n",
      "Done in 11.163130283355713 seconds\n",
      "Creating criterion and optimizer objects\n",
      "Done in 0.0013878345489501953 seconds\n",
      "Creating train and val dataloaders\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages/mindnlp/transformers/tokenization_utils_base.py:1526: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be depracted, and will be then set to `False` by default. \n",
      "  warnings.warn(\n",
      "/home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages/mindnlp/transformers/tokenization_utils_base.py:1526: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be depracted, and will be then set to `False` by default. \n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Done in 3.1491587162017822 seconds\n",
      "Let the training begin\n",
      "Iteration 0 of epoch 0 complete. Loss : 0.72922682762146 Accuracy : 0.40625\n",
      "Iteration 500 of epoch 0 complete. Loss : 0.6815673112869263 Accuracy : 0.5\n",
      "Iteration 1000 of epoch 0 complete. Loss : 0.6588367223739624 Accuracy : 0.5625\n",
      "Iteration 1500 of epoch 0 complete. Loss : 0.633255124092102 Accuracy : 0.59375\n",
      "Iteration 2000 of epoch 0 complete. Loss : 0.6255167722702026 Accuracy : 0.6875\n",
      "Epoch 0 complete! Validation Accuracy : 0.70870537, Validation Loss : 0.606412410736084\n",
      "Best validation accuracy improved from 0 to 0.70870537\n",
      "Iteration 0 of epoch 1 complete. Loss : 0.6649124622344971 Accuracy : 0.5625\n",
      "Iteration 500 of epoch 1 complete. Loss : 0.5791333913803101 Accuracy : 0.78125\n",
      "Iteration 1000 of epoch 1 complete. Loss : 0.5701107978820801 Accuracy : 0.875\n",
      "Iteration 1500 of epoch 1 complete. Loss : 0.4862162470817566 Accuracy : 0.90625\n",
      "Iteration 2000 of epoch 1 complete. Loss : 0.6048108339309692 Accuracy : 0.6875\n",
      "Epoch 1 complete! Validation Accuracy : 0.7901786, Validation Loss : 0.5464124296392713\n",
      "Best validation accuracy improved from 0.70870537 to 0.7901786\n",
      "Iteration 0 of epoch 2 complete. Loss : 0.5632082223892212 Accuracy : 0.8125\n",
      "Iteration 500 of epoch 2 complete. Loss : 0.5999962091445923 Accuracy : 0.6875\n",
      "Iteration 1000 of epoch 2 complete. Loss : 0.5536276698112488 Accuracy : 0.75\n",
      "Iteration 1500 of epoch 2 complete. Loss : 0.5445767641067505 Accuracy : 0.8125\n",
      "Iteration 2000 of epoch 2 complete. Loss : 0.5306562185287476 Accuracy : 0.8125\n",
      "Epoch 2 complete! Validation Accuracy : 0.8125, Validation Loss : 0.502809716122491\n",
      "Best validation accuracy improved from 0.7901786 to 0.8125\n",
      "Iteration 0 of epoch 3 complete. Loss : 0.545007050037384 Accuracy : 0.84375\n",
      "Iteration 500 of epoch 3 complete. Loss : 0.5180479288101196 Accuracy : 0.84375\n",
      "Iteration 1000 of epoch 3 complete. Loss : 0.5651894807815552 Accuracy : 0.71875\n",
      "Iteration 1500 of epoch 3 complete. Loss : 0.4774022102355957 Accuracy : 0.84375\n",
      "Iteration 2000 of epoch 3 complete. Loss : 0.46579915285110474 Accuracy : 0.875\n",
      "Epoch 3 complete! Validation Accuracy : 0.8058036, Validation Loss : 0.48010874752487454\n",
      "Iteration 0 of epoch 4 complete. Loss : 0.5770348310470581 Accuracy : 0.71875\n",
      "Iteration 500 of epoch 4 complete. Loss : 0.46957165002822876 Accuracy : 0.8125\n",
      "Iteration 1000 of epoch 4 complete. Loss : 0.530113935470581 Accuracy : 0.71875\n",
      "Iteration 1500 of epoch 4 complete. Loss : 0.4496270418167114 Accuracy : 0.875\n",
      "Iteration 2000 of epoch 4 complete. Loss : 0.5048515796661377 Accuracy : 0.8125\n",
      "Epoch 4 complete! Validation Accuracy : 0.8125, Validation Loss : 0.46174625733069014\n",
      "Done in 1292.1033997535706 seconds\n"
     ]
    }
   ],
   "source": [
    "from types import SimpleNamespace\n",
    "\n",
    "# Run 1: BERT encoder frozen, only the linear head is trained\n",
    "args = SimpleNamespace(\n",
    "    device_target='Ascend',\n",
    "    device_id=0,\n",
    "    base_model_name_or_path='bert-base-uncased',\n",
    "    dataset_name_or_path='./data/SST-2',\n",
    "    freeze_bert=True,\n",
    "    maxlen=25,          # tokens per example, including [CLS]/[SEP]/padding\n",
    "    batch_size=32,\n",
    "    lr=2e-5,\n",
    "    print_every=500,    # log training loss/accuracy every N iterations\n",
    "    max_eps=5,          # number of epochs\n",
    "    save_path=None,     # save location for best_model.ckpt; None disables saving\n",
    ")\n",
    "\n",
    "main(args)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2d69446d",
   "metadata": {},
   "source": [
    "设置训练参数，开始训练(不冻结BERT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "1a6c8f11",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[WARNING] ME(1713:281473064419344,MainProcess):2024-12-28-13:10:29.874.179 [mindspore/context.py:1208] For 'context.set_context' in Ascend backend, the backend is already initialized, please set it before the definition of any Tensor and Parameter, and the instantiation and execution of any operation and net, otherwise the settings may not take effect. \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Building model! (This might take time if you are running this for first time)\n",
      "Done in 1.576096534729004 seconds\n",
      "Creating criterion and optimizer objects\n",
      "Done in 0.0017273426055908203 seconds\n",
      "Creating train and val dataloaders\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages/mindnlp/transformers/tokenization_utils_base.py:1526: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be depracted, and will be then set to `False` by default. \n",
      "  warnings.warn(\n",
      "/home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages/mindnlp/transformers/tokenization_utils_base.py:1526: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be depracted, and will be then set to `False` by default. \n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Done in 2.324294090270996 seconds\n",
      "Let the training begin\n",
      "Iteration 0 of epoch 0 complete. Loss : 0.7102465629577637 Accuracy : 0.5625\n",
      "Iteration 500 of epoch 0 complete. Loss : 0.16873285174369812 Accuracy : 0.90625\n",
      "Iteration 1000 of epoch 0 complete. Loss : 0.08603419363498688 Accuracy : 0.96875\n",
      "Iteration 1500 of epoch 0 complete. Loss : 0.10500945895910263 Accuracy : 0.9375\n",
      "Iteration 2000 of epoch 0 complete. Loss : 0.19792640209197998 Accuracy : 0.90625\n",
      "Epoch 0 complete! Validation Accuracy : 0.8984375, Validation Loss : 0.28035964673784164\n",
      "Best validation accuracy improved from 0 to 0.8984375\n",
      "Iteration 0 of epoch 1 complete. Loss : 0.04876020550727844 Accuracy : 1.0\n",
      "Iteration 500 of epoch 1 complete. Loss : 0.18226751685142517 Accuracy : 0.9375\n",
      "Iteration 1000 of epoch 1 complete. Loss : 0.3231828808784485 Accuracy : 0.90625\n",
      "Iteration 1500 of epoch 1 complete. Loss : 0.2185860425233841 Accuracy : 0.90625\n",
      "Iteration 2000 of epoch 1 complete. Loss : 0.05135542154312134 Accuracy : 1.0\n",
      "Epoch 1 complete! Validation Accuracy : 0.90066963, Validation Loss : 0.3017509060778788\n",
      "Best validation accuracy improved from 0.8984375 to 0.90066963\n",
      "Iteration 0 of epoch 2 complete. Loss : 0.0862235352396965 Accuracy : 0.9375\n",
      "Iteration 500 of epoch 2 complete. Loss : 0.23206503689289093 Accuracy : 0.9375\n",
      "Iteration 1000 of epoch 2 complete. Loss : 0.10999009013175964 Accuracy : 0.9375\n",
      "Iteration 1500 of epoch 2 complete. Loss : 0.06986917555332184 Accuracy : 0.96875\n",
      "Iteration 2000 of epoch 2 complete. Loss : 0.018712937831878662 Accuracy : 1.0\n",
      "Epoch 2 complete! Validation Accuracy : 0.8816964, Validation Loss : 0.3697358641241278\n",
      "Iteration 0 of epoch 3 complete. Loss : 0.12822596728801727 Accuracy : 0.9375\n",
      "Iteration 500 of epoch 3 complete. Loss : 0.012879552319645882 Accuracy : 1.0\n",
      "Iteration 1000 of epoch 3 complete. Loss : 0.05736008659005165 Accuracy : 0.96875\n",
      "Iteration 1500 of epoch 3 complete. Loss : 0.09108293056488037 Accuracy : 0.96875\n",
      "Iteration 2000 of epoch 3 complete. Loss : 0.08910780400037766 Accuracy : 0.96875\n",
      "Epoch 3 complete! Validation Accuracy : 0.87834823, Validation Loss : 0.39812061549829586\n",
      "Iteration 0 of epoch 4 complete. Loss : 0.02146654762327671 Accuracy : 1.0\n",
      "Iteration 500 of epoch 4 complete. Loss : 0.12440156191587448 Accuracy : 0.96875\n",
      "Iteration 1000 of epoch 4 complete. Loss : 0.005717694293707609 Accuracy : 1.0\n",
      "Iteration 1500 of epoch 4 complete. Loss : 0.09537044167518616 Accuracy : 0.96875\n",
      "Iteration 2000 of epoch 4 complete. Loss : 0.011304444633424282 Accuracy : 1.0\n",
      "Epoch 4 complete! Validation Accuracy : 0.88616073, Validation Loss : 0.4023839123547077\n",
      "Done in 3904.605708837509 seconds\n"
     ]
    }
   ],
   "source": [
    "from types import SimpleNamespace\n",
    "\n",
    "# Run 2: full fine-tuning (BERT encoder weights are updated too)\n",
    "args = SimpleNamespace(\n",
    "    device_target='Ascend',\n",
    "    device_id=0,\n",
    "    base_model_name_or_path='bert-base-uncased',\n",
    "    dataset_name_or_path='./data/SST-2',\n",
    "    freeze_bert=False,\n",
    "    maxlen=25,          # tokens per example, including [CLS]/[SEP]/padding\n",
    "    batch_size=32,\n",
    "    lr=2e-5,\n",
    "    print_every=500,    # log training loss/accuracy every N iterations\n",
    "    max_eps=5,          # number of epochs\n",
    "    save_path=None,     # save location for best_model.ckpt; None disables saving\n",
    ")\n",
    "\n",
    "main(args)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
